# md5 bank 0 c77ec4a018a39945dbb9742f6c37323e
# md5 bank 1 1cd95f427cba732025569be58aaa3aae
#

import hashlib
import md5
import string
import struct

class AssemblerError(Exception):
    """Raised when the source listing cannot be assembled."""


class HP97():
    """Multi-pass assembler producing two 4096-word ROM bank images.

    ``assemble()`` reads a textual listing, resolves labels over repeated
    passes, writes a human-readable listing file and dumps each bank as
    little-endian 16-bit words.

    Source lines are split on whitespace.  A line whose first character is
    greater than a space starts either with a preprocessor directive
    (``#define``, ``#ifdef``, ``#else``, ``#endif``) or with a label whose
    last character is dropped (presumably a trailing ``:``).  Labels that
    start with ``.`` are local: they are prefixed with the most recent
    global label.  Tokens following a recognised instruction are carried
    into the listing as a comment.
    """

    # --- assembler state (class-level defaults, reset by assemble()) -------
    _rom = {}                   # replaced by a list of 8192 words in assemble()
    _labels = {}                # qualified label name -> address | (bank << 12)
    _last_global = ''           # most recent global label, prefix for '.local'
    _pc = 0                     # current address within the bank
    _bank = 0                   # current bank (0 or 1 / False or True)
    _ifthen = 0                 # 1 after an 'if ...' test, cleared by 'then go to'
    _cy = 0                     # 1 after a carry-setting op, cleared by a branch
    _del_rom = 0                # auto-generated 'del sel rom' word to emit
    _del_rom_emit = 0           # 1 when _del_rom must precede the branch word
    _delta_labels = 0           # 1 when a label moved during the current pass
    _del_rom_force = 0          # 1 right after a hand-written 'del sel rom'
    _del_rom_force_rom = 0      # ROM number selected by that hand-written op
    _defines = {}               # '#define' name -> integer value
    _cur_define = ''            # name tested by the innermost '#ifdef'
    _cur_do_line = 1            # 0 while inside a disabled '#ifdef' region

    _pass = 0                   # number of the resolving pass being run

    # Arithmetic mnemonics; '%s' is replaced by a field name from _op_tef.
    # The tuple index is the operation number encoded by _find_arith().
    _op_arith = ('0 -> a[%s]', '0 -> b[%s]',
                 'a <-> b[%s]', 'a -> b[%s]',
                 'a <-> c[%s]', 'c -> a[%s]',
                 'b -> c[%s]', 'b <-> c[%s]',
                 '0 -> c[%s]', 'a + b -> a[%s]',
                 'a + c -> a[%s]', 'c + c -> c[%s]',
                 'a + c -> c[%s]', 'a + 1 -> a[%s]',
                 'shift left a[%s]', 'c + 1 -> c[%s]',
                 'a - b -> a[%s]', 'a - c -> c[%s]',
                 'a - 1 -> a[%s]', 'c - 1 -> c[%s]',
                 '0 - c -> c[%s]', '0 - c - 1 -> c[%s]',
                 'if b[%s] = 0', 'if c[%s] = 0',
                 'if a >= c[%s]', 'if a >= b[%s]',
                 'if a[%s] # 0', 'if c[%s] # 0',
                 'a - c -> a[%s]', 'shift right a[%s]',
                 'shift right b[%s]', 'shift right c[%s]')
    # Parallel to _op_arith: 1 when the operation may be followed by
    # 'if n/c go to' (it becomes the value of self._cy).
    _op_arith_cy = (0, 0, 0, 0, 0, 0, 0, 0,
                    0, 1, 1, 1, 1, 1, 0, 1,
                    1, 1, 1, 1, 1, 1, 0, 0,
                    0, 0, 0, 0, 1, 0, 0, 0)
    # Field names substituted into _op_arith; the index is encoded too.
    _op_tef = ('p', 'wp', 'xs', 'x', 's', 'm', 'w', 'ms')

    # Miscellaneous mnemonics.  Table N supplies the `line` argument and the
    # index inside the table the `col` argument of _calc_code().  Duplicate
    # strings inside a table always assemble to the first index.
    _op_misc_0 = ('nop', 'crc motor on?', 'crc ?20?', 'crc test f1',
                  'crc ?40?', 'crc test f2', 'crc ?60?', 'crc test f3',
                  'crc set f4', 'crc test f4', 'crc set f0', 'crc clear f0',
                  'crc set f1', 'crc clear f1', 'crc ?E0?', 'crc wr prot?')

    _op_misc_1 = ('1 -> S0', '1 -> S1', '1 -> S2', '1 -> S3',
                  '1 -> S4', '1 -> S5', '1 -> S6', '1 -> S7',
                  '1 -> S8', '1 -> S9', '1 -> S10', '1 -> S11',
                  '1 -> S12', '1 -> S13', '1 -> S14', '1 -> S15')

    _op_misc_2 = ('clear reg', 'clear S', 'display toggle', 'display off',
                  'm1 <-> c', 'm1 -> c', 'm2 <-> c', 'm2 -> c',
                  'stack -> a', 'down rotate', 'y -> a', 'c -> stack',
                  'decimal', 'unknown D2', 'f -> a', 'f <-> a')

    _op_misc_3 = ('0 -> S0', '0 -> S1', '0 -> S2', '0 -> S3',
                  '0 -> S4', '0 -> S5', '0 -> S6', '0 -> S7',
                  '0 -> S8', '0 -> S9', '0 -> S10', '0 -> S11',
                  '0 -> S12', '0 -> S13', '0 -> S14', '0 -> S15')

    _op_misc_4 = ('keys -> rom addr', 'keys -> a', 'a -> rom addr',
                  'display reset twf', 'binary', 'circulate a left',
                  'p - 1 -> p', 'p + 1 -> p', 'return',
                  'pik home?', 'pik cr?', 'pik keys?', 'pik ?C4?',
                  'pik ?D4?', 'pik ?E4?', 'pik print3')

    _op_misc_5 = ('if S0 = 1', 'if S1 = 1', 'if S2 = 1', 'if S3 = 1',
                  'if S4 = 1', 'if S5 = 1', 'if S6 = 1', 'if S7 = 1',
                  'if S8 = 1', 'if S9 = 1', 'if S10 = 1', 'if S11 = 1',
                  'if S12 = 1', 'if S13 = 1', 'if S14 = 1', 'if S15 = 1')

    _op_misc_6 = ('load constant 0', 'load constant 1', 'load constant 2', 'load constant 3',
                  'load constant 4', 'load constant 5', 'load constant 6', 'load constant 7',
                  'load constant 8', 'load constant 9', 'load constant 10', 'load constant 11',
                  'load constant 12', 'load constant 13', 'load constant 14', 'load constant 15')

    _op_misc_7 = ('if S0 = 0', 'if S1 = 0', 'if S2 = 0', 'if S3 = 0',
                  'if S4 = 0', 'if S5 = 0', 'if S6 = 0', 'if S7 = 0',
                  'if S8 = 0', 'if S9 = 0', 'if S10 = 0', 'if S11 = 0',
                  'if S12 = 0', 'if S13 = 0', 'if S14 = 0', 'if S15 = 0')

    _op_misc_8 = ('sel rom 0', 'sel rom 1', 'sel rom 2', 'sel rom 3',
                  'sel rom 4', 'sel rom 5', 'sel rom 6', 'sel rom 7',
                  'sel rom 8', 'sel rom 9', 'sel rom A', 'sel rom B',
                  'sel rom C', 'sel rom D', 'sel rom E', 'sel rom F')

    _op_misc_9 = ('if p = 4', 'if p = 8', 'if p = 12', 'if p = 2',
                  'if p = 9', 'if p = 1', 'if p = 6', 'if p = 3',
                  'if p = 1', 'if p = 13', 'if p = 5', 'if p = 0',
                  'if p = 11', 'if p = 10', 'if p = 7', 'if p = 4')

    _op_misc_A = ('c -> data r0', 'c -> data r1', 'c -> data r2', 'c -> data r3',
                  'c -> data r4', 'c -> data r5', 'c -> data r6', 'c -> data r7',
                  'c -> data r8', 'c -> data r9', 'c -> data rA', 'c -> data rB',
                  'c -> data rC', 'c -> data rD', 'c -> data rE', 'c -> data rF')

    _op_misc_B = ('if p # 4', 'if p # 8', 'if p # 12', 'if p # 2',
                  'if p # 9', 'if p # 1', 'if p # 6', 'if p # 3',
                  'if p # 1', 'if p # 13', 'if p # 5', 'if p # 0',
                  'if p # 11', 'if p # 10', 'if p # 7', 'if p # 4')

    _op_misc_C = ('crc ?0C?', 'crc ?1C?', 'crc motor on', 'crc motor off',
                  'crc ?4C?', 'crc card in?', 'crc test prot', 'crc ?7C?',
                  'bank switch', 'c -> addr', 'clear data regs', 'c -> data',
                  'rom selftest', 'crc ?DC?', 'pik print6', "hi i'm woodstock")

    _op_misc_D = ('del sel rom 0', 'del sel rom 1', 'del sel rom 2', 'del sel rom 3',
                  'del sel rom 4', 'del sel rom 5', 'del sel rom 6', 'del sel rom 7',
                  'del sel rom 8', 'del sel rom 9', 'del sel rom A', 'del sel rom B',
                  'del sel rom C', 'del sel rom D', 'del sel rom E', 'del sel rom F')

    _op_misc_E = ('data -> c', 'data r1 -> c', 'data r2 -> c', 'data r3 -> c',
                  'data r4 -> c', 'data r5 -> c', 'data r6 -> c', 'data r7 -> c',
                  'data r8 -> c', 'data r9 -> c', 'data rA -> c', 'data rB -> c',
                  'data rC -> c', 'data rD -> c', 'data rE -> c', 'data rF -> c')

    _op_misc_F = ('14 -> p', '4 -> p', '7 -> p', '8 -> p',
                  '11 -> p', '2 -> p', '10 -> p', '12 -> p',
                  '1 -> p', '3 -> p', '13 -> p', '6 -> p',
                  '0 -> p', '9 -> p', '5 -> p', '14 -> p')

    # All misc tables in search order; the position is the `line` number.
    _op_misc = (_op_misc_0, _op_misc_1, _op_misc_2, _op_misc_3,
                _op_misc_4, _op_misc_5, _op_misc_6, _op_misc_7,
                _op_misc_8, _op_misc_9, _op_misc_A, _op_misc_B,
                _op_misc_C, _op_misc_D, _op_misc_E, _op_misc_F)
    # Misc lines holding 'if ...' tests; a match allows 'then go to' next.
    _op_misc_if_lines = (5, 7, 9, 11)

    # Branch mnemonics, tested before anything else by _find_opcode().
    _op_branch = ('then go to', 'if n/c go to', 'go to', 'jsb')

    def __init__(self):
        # Give each instance its own containers instead of sharing the
        # class-level dictionaries declared above.
        self._labels = {}
        self._defines = {}

    # ------------------------------------------------------------------
    # mnemonic matching
    # ------------------------------------------------------------------

    def _match(self, l, ops):
        """Find the first mnemonic in `ops` that is a token prefix of `l`.

        `l` is the list of tokens of a source line, `ops` a sequence of
        mnemonic strings.  Returns ``(index, token_count)`` for the first
        mnemonic whose whitespace-separated tokens equal the leading tokens
        of `l`, or ``(-1, 0)`` when none matches.
        """
        for index, op in enumerate(ops):
            tokens = op.split()
            count = len(tokens)
            if count <= len(l) and list(l[:count]) == tokens:
                return (index, count,)
        return (-1, 0,)

    def _calc_code(self, col, line, klass):
        """Pack column, line and class into one instruction word."""
        return col << 6 | line << 2 | klass

    def _find_misc(self, ll):
        """Look `ll` up in the misc tables, in order 0..F.

        Returns ``(code, token_count)`` or ``(-1, 0)``.  Matching one of the
        'if ...' tables also sets ``self._ifthen``.
        """
        for line, table in enumerate(self._op_misc):
            found, length = self._match(ll, table)
            if found >= 0:
                if line in self._op_misc_if_lines:
                    self._ifthen = 1
                return (self._calc_code(found, line, 0), length,)
        return (-1, 0,)

    def _find_arith(self, ll):
        """Look `ll` up among the arithmetic mnemonics for every field.

        Returns ``(code, token_count)`` or ``(-1, 0)``.  On a match
        ``self._cy`` is loaded from ``_op_arith_cy`` and the six 'if ...'
        operations (indices 22..27) set ``self._ifthen``.
        """
        for field, tef in enumerate(self._op_tef):
            candidates = [op % tef for op in self._op_arith]
            found, length = self._match(ll, candidates)
            if found >= 0:
                self._cy = self._op_arith_cy[found]
                if 22 <= found <= 27:
                    self._ifthen = 1
                return (found << 5 | field << 2 | 0x002, length,)
        return (-1, 0,)

    # ------------------------------------------------------------------
    # instruction encoding
    # ------------------------------------------------------------------

    def _resolve_target(self, ll, length, last):
        """Return the address of the label at ``ll[length]``.

        Unknown labels yield -1 on intermediate passes and raise
        AssemblerError on the final one (`last` true).
        """
        adr = self._find_label(ll[length])
        if last and adr < 0:
            print('%s label not found' % (ll,))
            raise AssemblerError('Label not found')
        return adr

    def _encode_branch(self, ll, kind, length, passe, last):
        """Encode a branch; `kind` is the index into ``_op_branch``.

        `length` is the number of mnemonic tokens; the label operand is
        ``ll[length]``.  Returns ``(code, tokens_consumed)``.
        """
        if kind == 0:                               # then go to
            if self._ifthen == 0:
                raise AssemblerError('Error, then go to without if')
            self._ifthen = 0
            if passe == 0:
                return (0x000, length + 1,)
            if passe == 1:
                adr = self._resolve_target(ll, length, last)
                # The word holds the target relative to the 1K block of pc.
                code = adr - (self._pc & 0xC00)
                if (code < 0 or code > 1023) and last:
                    print('%s %s %s %s %s' % (ll[length], adr, code,
                                              self._pc, self._bank))
                    raise AssemblerError('Error, then go to too far')
                return (code, length + 1,)
        elif kind == 1:                             # if n/c go to
            if self._cy == 0:
                raise AssemblerError(
                    'Error, if n/c then go to without cy operation')
            self._cy = 0
            if passe == 0:
                return (0x003, length + 1,)
            if passe == 1:
                adr = self._resolve_target(ll, length, last)
                dist = adr - (self._pc & 0xF00)
                if (dist < 0 or dist > 255) and last:
                    print(ll)
                    raise AssemblerError('Error, if n/c then go to too far')
                return (dist << 2 | 0x003, length + 1,)
        elif kind in (2, 3):                        # go to / jsb
            self._cy = 0
            if passe in (0, 1):
                if kind == 3:
                    low_bits = 0x001
                else:
                    low_bits = 0x003
                if passe == 0:
                    adr = self._find_label(ll[length])
                else:
                    adr = self._resolve_target(ll, length, last)
                if passe == 0 and adr < 0:
                    # Forward reference on the first pass: size only.
                    code = low_bits
                else:
                    dist = adr - (self._pc & 0xF00)
                    if dist < 0 or dist > 255:
                        # Target lies outside the current 256-word page.
                        if self._del_rom_force:
                            # The source already has its own 'del sel rom'.
                            if (passe == 1 and last and
                                    (adr >> 8) != self._del_rom_force_rom):
                                raise AssemblerError(
                                    'manual del sel rom not on target')
                        else:
                            # Insert a 'del sel rom' word before the branch.
                            self._del_rom_emit = 1
                            self._del_rom = (adr >> 8) << 6 | 0x034
                        dist = adr & 0x0FF
                    code = dist << 2 | low_bits
                self._del_rom_force = 0
                return (code, length + 1,)
        return (-1, 0,)

    def _find_opcode(self, ll, passe, last):
        """Translate the tokens `ll` of one statement.

        `passe` is 0 on the label-collecting pass and 1 on every later
        pass; `last` is true on the final pass, the only one on which
        range/target errors are raised.

        Returns ``(code, length)``: `code` is the instruction word or -1
        when no word is produced (``org``, ``bank``, unrecognised text) and
        `length` the number of tokens consumed.  May update ``_pc``,
        ``_bank``, ``_cy``, ``_ifthen`` and the 'del sel rom' state.

        Raises AssemblerError for assembly errors.
        """
        found, length = self._match(ll, self._op_branch)
        if found >= 0:
            return self._encode_branch(ll, found, length, passe, last)

        # Anything but a branch cancels a pending hand-written 'del sel rom'.
        self._del_rom_force = 0

        code, length = self._find_misc(ll)
        if code >= 0:
            if code == 0x230:                       # bank switch
                # The operand label must be the next address.
                if last and self._find_label(ll[length]) != (self._pc + 1):
                    raise AssemblerError('bank switch not on target')
                length = length + 1
            if (code & 0x03F) == 0x020:             # sel rom
                if last:
                    # NOTE(review): when pc is the last word of a page the
                    # '+ 1' carries into the ROM-number bits -- confirm
                    # that this is the intended destination.
                    dest = ((code >> 6) << 8) | ((self._pc & 0x0FF) + 1)
                    if self._find_label(ll[length]) != dest:
                        raise AssemblerError('sel rom not on target')
                length = length + 1
            if (code & 0x03F) == 0x034:             # del sel rom
                self._del_rom_force = 1
                self._del_rom_force_rom = (code >> 6)
            return (code, length,)

        code, length = self._find_arith(ll)
        if code >= 0:
            return (code, length,)

        if ll[0] == 'org':                          # org 0xXXX
            target = int(ll[1], 0)
            if last and self._pc > target:
                print('%03X > %03X' % (self._pc, target,))
                raise AssemblerError('org error, base > current pc')
            self._pc = target
            return (-1, 2,)
        elif ll[0] == 'bank':                       # bank 0 1
            self._bank = (ll[1] != '0')
            return (-1, 2,)
        return (-1, 0,)

    # ------------------------------------------------------------------
    # labels and defines
    # ------------------------------------------------------------------

    def _qualify_label(self, name):
        """Return the full name of a label being defined.

        A global name becomes the new prefix for local labels; a local
        name (leading '.') is prefixed with the current global label.
        """
        if name[0] != '.':
            self._last_global = name
            return name
        return self._last_global + name

    def _add_label(self, name, address):
        """Define label `name` at `address` in the current bank.

        Raises AssemblerError if the label already exists.
        """
        name = self._qualify_label(name)
        if name in self._labels:
            print(name)
            raise AssemblerError('Error, label already defined')
        self._labels[name] = address | (self._bank << 12)

    def _correct_label(self, name, address):
        """Update an existing label and flag the pass if it moved.

        Raises AssemblerError if the label was never defined.
        """
        name = self._qualify_label(name)
        if name not in self._labels:
            print(name)
            raise AssemblerError('Error, label not defined')
        value = address | (self._bank << 12)
        if self._labels[name] != value:
            self._delta_labels = 1
            self._labels[name] = value

    def _find_label(self, name):
        """Return the 12-bit address of `name`, or -1 if it is unknown."""
        if name[0] == '.':
            name = self._last_global + name
        if name in self._labels:
            return self._labels[name] & 0xFFF
        return -1

    def _add_define(self, name, value):
        """Record a '#define'; raises AssemblerError on redefinition."""
        if name in self._defines:
            print(name)
            raise AssemblerError('Error, define already defined')
        self._defines[name] = value

    def _find_define(self, name):
        """Return the value of define `name`, or 0 if it is not defined."""
        return self._defines.get(name, 0)

    # ------------------------------------------------------------------
    # line handling
    # ------------------------------------------------------------------

    def _reset_pass_state(self):
        """Put every per-pass state variable back to its initial value."""
        self._last_global = ''
        self._pc = 0
        self._bank = 0
        self._ifthen = 0
        self._cy = 0
        self._del_rom = 0
        self._del_rom_emit = 0
        self._delta_labels = 0
        self._del_rom_force = 0
        self._del_rom_force_rom = 0
        self._defines = {}
        self._cur_define = ''
        self._cur_do_line = 1

    def _handle_directive(self, ll):
        """Apply a preprocessor directive; return True if `ll` was one.

        Directives are honoured even inside a disabled '#ifdef' region.
        '#ifdef' tests the *value* of the define, so a name defined as 0
        counts as not defined.
        """
        keyword = ll[0]
        if keyword == '#define':
            self._add_define(ll[1], int(ll[2], 0))
        elif keyword == '#ifdef':
            self._cur_define = ll[1]
            if self._find_define(ll[1]):
                self._cur_do_line = 1
            else:
                self._cur_do_line = 0
        elif keyword == '#else':
            self._cur_do_line = 1 - self._cur_do_line
        elif keyword == '#endif':
            self._cur_do_line = 1
            self._cur_define = ''
        else:
            return False
        return True

    def _split_line(self, line, register_label):
        """Tokenise `line` and peel off a directive or a leading label.

        `register_label(name, address)` is called for a label on an enabled
        line.  Returns None for a blank line, otherwise
        ``(is_directive, label, tokens)`` where `label` still carries its
        final character and `tokens` no longer contains it.
        """
        ll = line.split()
        if len(ll) == 0:
            return None
        label = ''
        if line[0] > ' ':
            if self._handle_directive(ll):
                return (True, '', ll,)
            if self._cur_do_line:
                label = ll[0]
                ll = ll[1:]
                register_label(label[:-1], self._pc)
        return (False, label, ll,)

    def _first_pass_line(self, line):
        """Pass 0: define labels and advance pc by the size of the line."""
        parsed = self._split_line(line, self._add_label)
        if parsed is None:
            return
        is_directive, label, ll = parsed
        if is_directive or not self._cur_do_line:
            return
        code = -1
        if len(ll) > 0:
            code, length = self._find_opcode(ll, 0, 0)
        if self._del_rom_emit:
            # branch preceded by an automatic 'del sel rom': two words
            self._pc = self._pc + 2
            self._del_rom_emit = 0
        elif code >= 0:
            self._pc = self._pc + 1

    def _emit_listing(self, listing, display, text):
        """Write one listing line, echoing it to stdout if `display`."""
        if display:
            print(text)
        listing.write(text + '\n')

    def _later_pass_line(self, line, last, listing, display):
        """Pass 1..n: re-resolve labels; on the final pass emit output.

        `listing` is the open listing file (only used when `last` is true).
        """
        parsed = self._split_line(line, self._correct_label)
        if parsed is None:
            return
        is_directive, label, ll = parsed
        if is_directive:
            if last:
                self._emit_listing(listing, display, '%X%03X %s' %
                                   (self._bank, self._pc, ' '.join(ll)))
            return
        if not self._cur_do_line:
            return

        label = label.ljust(20)[:20]
        code = -1
        opcode = ''
        com = ''
        if len(ll) > 0:
            code, length = self._find_opcode(ll, 1, last)
            com = ' '.join(ll[length:])
            if length > 0:
                opcode = ' '.join(ll[0:length])
                if len(opcode) < 30:
                    opcode = opcode.ljust(30)
                else:
                    opcode = opcode[:27] + '...'

        # _find_opcode may have moved pc/bank ('org', 'bank'), so read them
        # only now.
        bank = self._bank
        pc = self._pc
        if self._del_rom_emit:
            if last:
                self._rom[pc | (bank << 12)] = self._del_rom
                self._rom[(pc + 1) | (bank << 12)] = code
                self._emit_listing(listing, display,
                                   '%X%03X %s %03X %03X %30s %s' %
                                   (bank, pc, label, self._del_rom, code,
                                    opcode, com))
            self._pc = pc + 2
            self._del_rom_emit = 0
        elif code >= 0:
            if last:
                self._rom[pc | (bank << 12)] = code
                self._emit_listing(listing, display,
                                   '%X%03X %s %03X     %s %s' %
                                   (bank, pc, label, code, opcode, com))
            self._pc = pc + 1
        elif last:
            if opcode == '':
                self._emit_listing(listing, display,
                                   '%X%03X %s         %s' %
                                   (bank, pc, label, com))
            else:
                self._emit_listing(listing, display,
                                   '%X%03X %s         %s %s' %
                                   (bank, pc, label, opcode, com))
        self._pc = self._pc & 0xFFF

    def _write_bank(self, path, bank):
        """Write one 4096-word bank to `path`; return its md5 hex digest.

        Every word is stored as two bytes, low byte first.
        """
        base = bank * 4096
        image = struct.pack('<4096H', *self._rom[base:base + 4096])
        out = open(path, 'wb')
        try:
            out.write(image)
        finally:
            out.close()
        return hashlib.md5(image).hexdigest()

    # ------------------------------------------------------------------
    # entry point
    # ------------------------------------------------------------------

    def assemble(self, file_in, file_out, file_out0, file_out1, display=0):
        """Assemble `file_in` into a listing and two ROM bank images.

        file_in   -- path of the source text
        file_out  -- path of the listing written on the final pass
        file_out0 -- path of the binary image of bank 0
        file_out1 -- path of the binary image of bank 1
        display   -- when true, the listing is also printed to stdout

        Pass 0 collects the labels.  Further passes are repeated until no
        label moves any more; one additional (final) pass then fills the
        ROM and writes the listing.  The md5 digests of both images are
        printed next to the reference digests.

        Raises AssemblerError when the source cannot be assembled.
        """
        source = open(file_in, 'rt')
        try:
            lines = source.readlines()
        finally:
            source.close()

        # Start from a clean slate so that the method can be called again.
        self._labels = {}
        self._rom = 8192 * [0]

        # pass 0
        self._reset_pass_state()
        print('pass 0')
        for line in lines:
            self._first_pass_line(line)

        # pass 1 and 2 and ...
        listing = None
        last = 0
        finished = 0
        self._pass = 0
        try:
            while finished == 0:
                self._reset_pass_state()
                self._pass = self._pass + 1
                print('pass %d' % self._pass)

                for line in lines:
                    self._later_pass_line(line, last, listing, display)

                if self._delta_labels == 0:
                    if last == 0:
                        # Labels are stable: run once more, for real.
                        listing = open(file_out, 'wt')
                        last = 1
                    else:
                        finished = 1
        finally:
            if listing is not None:
                listing.close()

        digest0 = self._write_bank(file_out0, 0)
        digest1 = self._write_bank(file_out1, 1)

        print('ori: c77ec4a018a39945dbb9742f6c37323e')
        print('new: %s' % digest0)
        print('ori: 1cd95f427cba732025569be58aaa3aae')
        print('new: %s' % digest1)

        
# Script entry: assemble the source listing into the work listing and
# the two bank images; display=0 keeps the listing off stdout.
topcat = HP97()
topcat.assemble('97.src.lst',
                '97.src-wk.out',
                '97p_bank0.bin',
                '97p_bank1.bin',
                display=0)


